import re
import os
import warnings
import json

# Turn off HuggingFace tokenizers' internal parallelism (presumably to silence
# its fork-related warning; confirm if it matters for this workload).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Expose only GPU 0. This is set before `import torch` below so the setting is
# already in the environment when torch initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Suppress every warning process-wide, including library deprecation warnings.
warnings.filterwarnings("ignore")

import torch
import pickle
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from guidance import models, gen

# Device the model is moved to; use "cuda:0" to run on the GPU instead.
device = "cpu"  # "cuda:0"

# Directory with config.json, pytorch_model.bin and additional_ids_to_tokens.pkl.
# Can be overridden via the STO_ILM_MODEL_DIR environment variable.
MODEL_DIR = os.environ.get(
    "STO_ILM_MODEL_DIR", "/home/vcollura/TRIDENT/ilm-master/models/sto_ilm"
)

def load_model_and_tokenizer(model_dir: str, device: str = None):
    """Build the GPT-2 infilling model from ``model_dir`` plus a matching tokenizer.

    Reads ``config.json`` and ``pytorch_model.bin`` from ``model_dir`` and loads
    the weights non-strictly (missing/unexpected keys are printed, not fatal).
    The model is moved to ``device`` and put in eval mode. The stock ``gpt2``
    tokenizer is extended with the tokens in ``additional_ids_to_tokens.pkl``
    when that file exists, and the model's token embeddings are resized to the
    tokenizer's new length.

    Args:
        model_dir: Directory holding the checkpoint files.
        device: Target device. When ``None``, ``"cuda"`` is used if available,
            otherwise ``"cpu"``.

    Returns:
        ``(model, tokenizer, device)``, where ``device`` is the resolved device.

    Raises:
        ValueError: if the pickle contains neither a dict nor a list.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    config = GPT2Config.from_json_file(os.path.join(model_dir, "config.json"))
    model = GPT2LMHeadModel(config)
    # Load tensors onto the CPU first; the whole model is moved to `device` below.
    state_dict = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if missing:
        print(f"⚠️  Missing keys: {missing}")
    if unexpected:
        print(f"⚠️  Unexpected keys: {unexpected}")

    model.to(device)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    pkl_path = os.path.join(model_dir, "additional_ids_to_tokens.pkl")
    if not os.path.exists(pkl_path):
        print("no additional_ids_to_tokens.pkl.")
        return model, tokenizer, device

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # checkpoints from trusted sources.
    with open(pkl_path, "rb") as f:
        additional_tokens = pickle.load(f)

    # token -> id the checkpoint expects (only known when the pickle is a dict).
    expected_ids = {}
    if isinstance(additional_tokens, dict):
        # Add tokens in ascending id order so the ids assigned by add_tokens
        # line up with the checkpoint's embedding rows, independent of the
        # dict's insertion order.
        ordered = sorted(additional_tokens.items(), key=lambda kv: kv[0])
        new_tokens = [token for _, token in ordered]
        expected_ids = {token: token_id for token_id, token in ordered}
    elif isinstance(additional_tokens, list):
        new_tokens = additional_tokens
    else:
        raise ValueError(
            f"{pkl_path}: expected a dict (id -> token) or a list of tokens, "
            f"got {type(additional_tokens).__name__}"
        )

    if new_tokens:
        tokenizer.add_tokens(new_tokens)
        model.resize_token_embeddings(len(tokenizer))
        for token in new_tokens:
            token_id = tokenizer.convert_tokens_to_ids(token)
            print(f"Token: {token} -> ID: {token_id}")
            expected = expected_ids.get(token)
            if isinstance(expected, int) and expected != token_id:
                print(f"⚠️  Token {token!r} got ID {token_id}, checkpoint expects {expected}")

    return model, tokenizer, device

# Load the checkpoint once at import time and wrap it for guidance.
# An empty model_dir would resolve config.json / pytorch_model.bin against the
# current working directory, so use the configured MODEL_DIR.
model_dir = MODEL_DIR
hf_model, tokenizer, device = load_model_and_tokenizer(model_dir, device)

gpt2 = models.Transformers(model=hf_model, tokenizer=tokenizer, device=device)

def extract_concepts(text):
    """Split ``text`` into literal segments and ``<|...|>`` marker tokens.

    Markers are returned as standalone elements in their original order.
    Empty segments, such as the gap between two adjacent markers, are dropped.
    A falsy ``text`` yields an empty list.

    >>> extract_concepts("A <|infill_word|> b.")
    ['A ', '<|infill_word|>', ' b.']
    """
    if not text:
        return []
    # The capturing group makes re.split keep each marker in the output.
    pieces = re.split(r'(<\|.*?\|>)', text, flags=re.DOTALL)
    return [piece for piece in pieces if piece]

def concepts_to_regex(concepts):
    """Join one regex fragment per concept into a single pattern string.

    Each ``<|infill_*|>`` marker maps to the character-class pattern for that
    kind of fill. Any other concept is matched literally through ``re.escape``.
    """
    marker_patterns = {
        '<|infill_ngram|>': r"[a-zA-Z0-9' ,.!?]+(?:[ ,.][a-zA-Z0-9' ]+)*",
        '<|infill_word|>': r"[ ]?[a-zA-Z0-9'.!,?]+[ ]?",
        '<|infill_sentence|>': r"[a-zA-Z0-9', ]+[.!?]",
    }
    fragments = (
        marker_patterns[concept] if concept in marker_patterns else re.escape(concept)
        for concept in concepts
    )
    return ''.join(fragments)

def generate_infilled_text(prompt):
    """Fill every ``<|infill_*|>`` placeholder in ``prompt`` via constrained generation.

    The masked ``prompt`` is kept as the model's context. The story is then
    rewritten after it segment by segment:
    - literal text is copied verbatim;
    - each placeholder is replaced by text that ``gpt2`` generates under the
      regex ``concepts_to_regex`` gives that placeholder kind, capped at a
      per-kind token budget. Leading and trailing whitespace is stripped from
      each fill.

    Args:
        prompt: Story text containing ``<|infill_ngram|>``, ``<|infill_word|>``
            and/or ``<|infill_sentence|>`` markers.

    Returns:
        ``(filled_text, spans)``: the rewritten story without the masked
        context, and the generated fill-ins in order.

    Raises:
        Any exception from the guidance generation call propagates, so the
        caller can fall back. The return type is therefore always a 2-tuple.
    """
    # Token budget per placeholder kind; anything else is copied as literal text.
    max_tokens_by_marker = {
        '<|infill_ngram|>': 9,
        '<|infill_word|>': 4,
        '<|infill_sentence|>': 16,
    }

    concepts = extract_concepts(prompt)
    spans = []
    len_initial_prompt = len(prompt)
    for concept in concepts:
        max_new_tokens = max_tokens_by_marker.get(concept)
        if max_new_tokens is None:
            prompt += concept
            continue
        regex = concepts_to_regex([concept])
        span = gpt2 + f"{prompt}{gen(max_tokens=max_new_tokens, regex=regex)}"
        # str(span) is the prompt followed by the generated text; keep only the latter.
        generated_text = str(span)[len(prompt):].strip()
        prompt += generated_text
        spans.append(generated_text)
    return prompt[len_initial_prompt:], spans

def process_files(file_paths):
    """Run infilling over every story in each JSON file and save the results.

    Each file is assumed to hold a JSON list of objects with ``id``, ``title``
    and ``story`` keys (the keys read below). Each story is passed to
    ``generate_infilled_text``:
    - on success, the entry stores the title, ``"<title>\\n<generated story>"``
      and the generated spans;
    - on failure, the entry keeps the original story and has no ``spans`` key.

    Results are written to ``<input basename>_guidance.json`` in the current
    working directory. Files that cannot be read or parsed are reported and
    skipped.

    Args:
        file_paths: Iterable of paths to input JSON files.
    """
    from pprint import pprint

    for file_path in file_paths:
        print(f"Processing {file_path}...")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(f"Error reading file {file_path}: {str(e)}")
            continue

        output_dict = {}
        for item in data:
            try:
                generated_story, spans = generate_infilled_text(item["story"])
                entry = {
                    "title": item["title"],
                    "generated": f"{item['title']}\n{generated_story}",
                    "spans": spans,
                }
                output_dict[str(item["id"])] = entry
                # Print only this item's result, not the whole accumulated dict,
                # so log volume stays linear in the number of items.
                pprint(entry)
            except Exception as e:
                print(f"Error processing item {item['id']}: {str(e)}")
                # Fallback entry: original story, no "spans" key.
                output_dict[str(item["id"])] = {
                    "title": item["title"],
                    "generated": f"{item['title']}\n{item['story']}"
                }

        base_name = os.path.splitext(os.path.basename(file_path))[0]
        output_file = f"{base_name}_guidance.json"

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_dict, f, indent=2)

        print(f"Saved output to {output_file}")

if __name__ == "__main__":
    import sys

    # Input JSON files are taken from the command line; with no arguments the
    # list is empty and nothing is processed.
    file_list = sys.argv[1:]
    process_files(file_list)